{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import metrics\n",
    "\n",
    "from alpaca.ue import MCDUE\n",
    "from alpaca.utils.datasets.builder import build_dataset\n",
    "from alpaca.utils.ue_metrics import ndcg, classification_metric\n",
    "from alpaca.ue.masks import BasicBernoulliMask, DecorrelationMask\n",
    "import alpaca.nn as ann\n",
    "from alpaca.utils import model_builder\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = build_dataset('mnist', val_size=10000)\n",
    "x_train, y_train = mnist.dataset('train')\n",
    "x_val, y_val = mnist.dataset('val')\n",
    "x_shape = (-1, 1, 28, 28)\n",
    "\n",
    "train_ds = TensorDataset(torch.FloatTensor(x_train.reshape(x_shape)), torch.LongTensor(y_train))\n",
    "val_ds = TensorDataset(torch.FloatTensor(x_val.reshape(x_shape)), torch.LongTensor(y_val))\n",
    "train_loader = DataLoader(train_ds, batch_size=256)\n",
    "val_loader = DataLoader(val_ds, batch_size=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleConv(nn.Module):\n",
    "    def __init__(self, num_classes=10, activation=None):\n",
    "        if activation is None:\n",
    "            self.activation = F.leaky_relu\n",
    "        else:\n",
    "            self.activation = activation\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 16, 3)\n",
    "        self.conv2 = nn.Conv2d(16, 32, 3)\n",
    "        self.linear_size = 12*12*32\n",
    "        self.fc1 = nn.Linear(self.linear_size, 256)\n",
    "        self.dropout = ann.Dropout()\n",
    "        self.fc2 = nn.Linear(256, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.activation(self.conv1(x))\n",
    "        x = self.activation(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = x.view(-1, self.linear_size)\n",
    "        x = self.activation(self.fc1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SimpleConv()\n",
    "model = model_builder.build_model(model, dropout_rate=0.5, dropout_mask=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 235/235 [00:36<00:00,  6.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Train loss on last batch 0.04960529878735542\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "for x_batch, y_batch in tqdm(train_loader):\n",
    "    prediction = model(x_batch)\n",
    "    optimizer.zero_grad()\n",
    "    loss = criterion(prediction, y_batch)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "print('\\nTrain loss on last batch', loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model_builder.build_model(model, dropout_rate=0.5, dropout_mask=BasicBernoulliMask())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Uncertainty estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Uncertainty estimation with MCDUE_classification approach: 100%|██████████| 25/25 [00:02<00:00, 12.36it/s]\n"
     ]
    }
   ],
   "source": [
    "# Calculate uncertainty estimation\n",
    "x_batch, y_batch = next(iter(val_loader))\n",
    "\n",
    "estimator = MCDUE(model, num_classes=10, acquisition=\"bald\")\n",
    "predictions, estimations = estimator(x_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "score = classification_metric(estimations, np.equal(predictions.argmax(axis=-1), y_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'Accuracy')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAEGCAYAAACtqQjWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de5SeZX3v//cnM5PDTI5zSIAkM5OQKAYaAwYELSVitUFbELQSfvWAPxW7K7p7oBWWa2N3fmWh/ljLFqG6UpsKLgVsumtDGxrcgRSs4CYICYc0MATIicPkTM7J5Lv/uK8JN8MkeZJ5nnmeZ+bzWutec9/XfbquNWG+XIf7uhQRmJmZFcOQcmfAzMwGDgcVMzMrGgcVMzMrGgcVMzMrGgcVMzMrmtpyZ6Ccmpubo729vdzZMDOrKo8//vjmiGjp7dygDirt7e2sWLGi3NkwM6sqkl4+2jk3f5mZWdE4qJiZWdE4qJiZWdE4qJiZWdE4qJiZWdGUNKhIWijpdUlPH+W8JN0qqUPSKknn5M59VtLzaftsLv09kp5K99wqSSm9UdLP0/U/lzSulGUzM7O3K3VN5YfA3GOcvwSYnrZrgO9BFiCAbwDvBc4DvpELEt8Dvpi7r/v51wPLImI6sCwdm5lZPyrpdyoR8ZCk9mNcchlwZ2Tz7z8qaaykU4E5wM8jYiuApJ8DcyUtB0ZHxKMp/U7gY8B96Vlz0nPvAJYDXytuiTJrXn2Df1u1qRSPNhu0GobV8v/+5hTqatwqX83K/fHjRGB97nhDSjtW+oZe0gEmRMQraf9VYEJvL5R0DVmtiNbW1pPKdMfru/jugx0nda+ZvV33sk6/MXEM75vWXN7MWJ+UO6iURESEpF5XH4uIBcACgNmzZ5/UCmUfnXkqH5350T7k0MzyNm7fy/u/+QDrtu7hfeXOjPVJueuZG4HJueNJKe1Y6ZN6SQd4LTWdkX6+XqI8m1mRnTJ6OHU1Yt3WPeXOivVRuYPKYuAzaRTY+cCO1IS1FPiwpHGpg/7DwNJ0bqek89Oor88A/5J7Vvcosc/m0s2swtUMERPHjnBQGQBK2vwl6S6yzvNmSRvIRnTVAUTE94ElwEeADmAP8Ll0bquk/w94LD1qfnenPfBHZKPKRpB10N+X0r8J/FTS54GXgU+WsmxmVlyTG+tZ76BS9Uo9+uuq45wP4MtHObcQWNhL+grgrF7StwAfPLmcmlm5tTbWs+SpV45/oVW0cjd/mZkBWVDZtucgO/cdLHdWrA8cVMysIrQ21gO4CazKOaiYWUWY7KAyIDiomFlFaG3KgopHgFU3BxUzqwijh9cxZkQd67fuLXdWrA8cVMysYrQ21rumUuUcVMysYrT6W5Wq56BiZhVjcmM9G7btpevwSU3LZxXAQcXMKkZrYz0Hug7z2s595c6KnSQHFTOrGN3fqrhfpXo5qJhZxXBQqX4OKmZWMU4dO5wh8geQ1cxBxcwqRl3NEE4bO8JBpYo5qJhZRfG3KtXNQcXMKkoWVPxVfbVyUDGzijK5sZ7Nu/az58ChcmfFTkJJg4qkuZLWSOqQdH0v59skLZO0StJySZNy574l6em0XZlLf1jSk2nbJOlnKX2OpB25czeWsmxmVhpvToHv2ko1KtnKj5JqgNuBDwEbgMckLY6IZ3OX3QLcGRF3SLoYuBn4tKSPAucAs4BhwHJJ90XEzoi4MPeOf+Kta9E/HBG/W6oymVnp5YcVv/OUUWXOjZ2oUtZUzgM6ImJtRBwA7gYu63HNDOCBtP9g7vwM4KGIOBQRu4FVwNz8jZJGAxcDPytR/s2sDPytSnUrZVCZCKzPHW9IaXkrgSvS/uXAKElNKX2upHpJzcAHgMk97v0YsCwidubSLpC0UtJ9ks4sVkHMrP+Mra9j5LBaDyuuUuXuqL8OuEjSE8BFwEagKyLuB5YAvwTuAh4Bunrce1U61+3XQFtEvBv4LkepwUi6RtIKSSs6OzuLWhgz6ztJTPZsxVWrlEFlI2+tXUxKaUdExKaIuCIizga+ntK2p583RcSsiPgQIOC57vtS7eU84N9yz9oZEbvS/hKgLl33FhGxICJmR8TslpaWIhXVzIqptXGEm7+qVCmDymPAdElTJA0F5gGL8xdIapbUnYcbgIUpvSY1gyFpJjATuD936yeAf42IfblnnSJJaf88srJtKUnJzKykuj+AjPAU+NWmZKO/IuKQpGuBpUANsDAinpE0H1gREYuBOcDNkgJ4CPhyur0OeDjFiJ3ApyIiP2h9HvDNHq/8BPDfJB0C9gLzwv8izapSa2M9+w8dpvON/YwfPbzc2bETULKgAkeaoZb0SLsxt78IWNTLffvIRoAd7blzekm7DbitD9k1swoxOTcCzEGlupS7o97M7G08rLh6OaiYWcWZOG4EkoNKNXJQMbOKM6y2hlNGD3dQqUIOKmZWkSY31rPB839VHQcVM6tIXlelOjmomFlFam2s59Wd+9h3sOdkGlbJHFTMrCJ1jwDbsM1NYNXEQcXMKtLkI+uquAmsmjiomFlF8rcq1clBxcwqUvPIoYyoq3FQqTIOKmZWkbIp8D1bcbVxUDGzitXqdVWqjoOKmVWs7sW6POF49XBQMbOK1dpYz+4DXWzdfaDcWbEClXTqezOzvugeAfaX9z5L7RCxfc8Bduw9yPa9B9l/8DDfuXIW501pLHMuLc81FTOrWGeeNoax9XX8x5rXWfHyVjbvOkD90FredcpoNu3Yyy9f2FzuLFoPrqmYWcU6Zcxwnrzxw72ee9/Ny1i3xZ34laakNRVJcyWtkdQh6fpezrdJWiZplaTlkiblzn1L0tNpuzKX/kNJL0p6Mm2zUrok3ZretUrSOaUsm5mVV2tTPS97ZFjFKVlQkVQD3A5cQrY08FWSei4RfAtwZ0TMBOYDN6d7PwqcA8wC3gtcJ2l07r4/j4hZaXsypV0CTE/bNcD3SlMyM6sEbY0NvOyaSsUpZU3lPKAjItZGxAHgbuCyHtfMAB5I+w/mzs8AHoqIQxGxG1gFzD3O+y4jC1AREY8CYyWdWoyCmFnlaW2qZ/Ou/ezef6jcWbGcUgaVicD63PGGlJa3Ergi7V8OjJLUlNLnSqqX1Ax8AJicu++m1MT1HUnDTuB9SLpG0gpJKzo7O0+2bGZWZm1N2cgw11YqS7lHf10HXCTpCeAiYCPQFRH3A0uAXwJ3AY8A3Ysq3ACcAZwLNAJfO5EXRsSCiJgdEbNbWlqKUwoz63dtjQ0ArNu6u8w5sbxSBpWNvLV2MSmlHRERmyLiiog4G/h6Stueft6U+kw+BAh4LqW/kpq49gP/QNbMVtD7zGzgaHVNpSKVMqg8BkyXNEXSUGAesDh/gaRmSd15uAFYmNJrUjMYkmYCM4H70/Gp6aeAjwFPp/sXA59Jo8DOB3ZExCslLJ+ZldGYEXWMra/zCLAKU7LvVCLikKRrgaVADbAwIp6RNB9YERGLgTnAzZICeAj4crq9Dng4ixvsBD4VEd29cT+W1EJWe3kS+MOUvgT4CNAB7AE+V6qymVllaGtq8LcqFaakHz9GxBKyP/b5tBtz+4uARb3ct49sBFhvz7z4KOnBm0HJzAaBtsZ6nli/rdzZsJxyd9SbmZ20tqZ6Nm3fx8Guw+XOiiUOKmZWtVob6+k6HGzctrfcWbHEQcXMqlZbUzas2J31lcNBxcyqVvcHkOu2+FuVSuGgYmZVa/yoYQyvG8JLHgFWMRxUzKxqSaK1sd4fQFYQBxUzq2ptTQ2eqqWCOKiYWVVra6xn3dY9ZJ+qWbk5qJhZVWtrqmffwcO8/sb+cmfFcFAxsyrX2j2s2P0qFcFBxcyqWltj92zF7lepBA4qZlbVJo4bQc0Qsc4fQFYEBxUzq2p1NUM4bexwN39VCAcVM6t6bY0Nbv6qEA4qZlb12prqPf9XhXBQMbOq19ZUz/Y9B9mx92C5szLolTSoSJoraY2kDknX93K+TdIySaskLZc0KXfuW5KeTtuVufQfp2c+LWmhpLqUPkfSDklPpu3Gnu8zs4GptTEbVuxVIMuvZEFFUg1wO3AJ2SqOV0nquZrjLcCdETETmA/cnO79KHAOMAt4L3CdpNHpnh8DZwC/AYwAvpB73sMRMStt80tTMjOrNN2zFb/s6VrKrpQ1lfOAjohYGxEHgLuBy3pcMwN4IO0/mDs/A3goIg5FxG5gFTAXsiWKIwH+DzAJMxvUWo98q+KaSrmVMqhMBNbnjjektLyVwBVp/3JglKSmlD5XUr2kZuADwOT8janZ69PAv+eSL5C0UtJ9ks7sLVOSrpG0QtKKzs7Oky2bmVWQhmG1NI8c5uavClDujvrrgIskPQFcBGwEuiLifmAJ8EvgLuARoKvHvX9LVpt5OB3/GmiLiHcD3wV+1tsLI2JBRMyOiNktLS1FL5CZlUc2AszNX+VWyqCykbfWLialtCMiYlNEXBERZwNfT2nb08+bUt/IhwABz3XfJ+kbQAvwp7ln7YyIXWl/CVCXajlmNgi0Nda7plIBjhtUJP2epJMJPo8B0yVNkTQUmAcs7vHs5tyzbwAWpvSa1AyGpJnATOD+dPwF4HeAqyLicO5Zp0hS2j8vlW3LSeTbzKpQa1M9r+zcx76DPRs1rD8VEiyuBJ6X9G1JZxT64Ig4BFwLLAVWAz+NiGckzZd0abpsDrBG0nPABOCmlF4HPCzpWWAB8Kn0PIDvp2sf6TF0+BPA05JWArcC88ILLJgNGu1NDUTAhm2urZRT7fEuiIhPpeG8VwE/lBTAPwB3RcQbx7l3CVnfSD7txtz+ImBRL/ftIxsB1tsze81zRNwG3Hbs0pjZQNXa9OYIsGnjR5U5N4NXQc1aEbGT7I//3cCpZCO1fi3pKyXMm5lZwdo8rLgiFNKncqmkfwaWkzVLnRcRlwDvBv6stNkzMytMY8NQRg6r9RT4ZXbc5i/g48B3IuKhfGJE7JH0+dJky8zsxEiitbHesxWXWSHNX39J9uU6AJJGSGoHiIhlJcmVmdlJ8GzF5VdIUPlH4HDuuCulmZlVlNamejZs3UvX4ZMf+BkRbNm1n5Xrt7Nl1/4i5m5wKKT5qzbN3QVARBxI352YmVWUtsYGDnQd5tWd+5g4dsQxr9257yDPbtrJs5t28tKW3WzYtpcN2/awYdte9hzIvnW5cHozP/r8e/sj6wNGIUGlU9KlEbEYQNJlwObSZsvM7MS1dw8r3rz7SFDZvf8Q67ftYd2WPax59Q2e2bSTZ1/Z+ZYO/dHDa5k0rp72pgZ+c1oLk8aN4H+vfo01rx7zqwnrRSFB5Q+BH0u6jWy6lPXAZ0qaKzOzk9D9rcq3/v2/kMT6rXvYsvvAW65pa6rnrImjufLcycw4bTRnnjaa8aOGv+1Zew928csXtrBr/yFGDivkT6VBYR8/vgCcL2lkOt5V8lyZmZ2EU8eM4KyJo9m25yCtjfV8+MwJTBpXz+TGeiaPG8Hp40cyenhdQc+a0pwt/PXS5t2cNXFMKbM9oBQUftOiWWcCw9P0WngRLDOrNDVDxL9+5cKiPKu9KQWVLQ4qJ6KQjx+/Tzb/11fImr9+H2grcb7MzMqqvTlrSntps797ORGFDCl+X0R8BtgWEf8TuAB4R2mzZWZWXvVDa5kwehgvbvZ3LyeikKCyL/3cI+k04CDZ/F9mZgNae1MDL/kL/RNSSFC5V9JY4P8nW13xJeAnpcyUmVklmNLc4OavE3TMjvq0gNaytBrjP0n6V2B4ROzol9yZmZXRlOYGtuw+wI69BxkzorBRY4PdMWsqaWXF23PH+x1QzGywaM8NK7bCFNL8tUzSx7uX6j0RkuZKWiOpQ9L1vZxvk7RM0ipJyyVNyp37lqSn03ZlLn2KpF+lZ97TPWWMpGHpuCOdbz/R/JqZ5R35VsX9KgUrJKh8iWwCyf2Sdkp6Q9LO490kqYaslnMJ2SqOV0nquZrjLcCdETETmA/cnO79KHAOMAt4L3BdWn0S4FtkU/FPA7YB3dPvf55shNo04DvpOjOzk9baWI8EL7qmUrDjBpWIGBURQyJiaESMTsejj3cfcB7QERFr04SUdwOX9bhmBvBA2n8wd34G8FBEHIqI3cAqYG6qLV3Mm0sQ3wF8LO1flo5J5z94MrUrM7Nuw+tqOG3MCDd/nYBCPn78rd62Ap49kWyesG4bUlreSuCKtH85MEpSU0qfK6leUjPwAWAy0ARsj4hDvTzzyPvS+R3p+p7luUbSCkkrOjs7CyiGmQ1m7c31vOgligtWyDQtf57bH05WA3mcrMbQV9cBt0m6GngI2Ah0RcT9ks4Ffgl0Ao+QrePSZxGxAFgAMHv27JNfdMHMBoX2pgb+ddUr5c5G1ShkQsnfyx9Lmgz8dQHP3khWu+g2KaXln72JVFNJE1Z+PA1fJiJuAm5K534CPAdsAcZKqk21kfwzu9+3QVItMCZdb2Z20qY0N7Bj70G27T7AuAYvJXU8hXTU97QBeFcB1z0GTE+jtYYC84DF+QskNadvYQBuABam9JrUDIakmcBM4P6ICLK+l0+kez4L/EvaX5yOSecfSNebmZ207hFga92vUpDj1lQkfRfo/uM8hGxE1q+Pd19EHJJ0LbAUqAEWRsQzkuYDK9KiX3OAmyUFWfPXl9PtdcDDqZ99J/CpXD/K14C7Jf0V8ATw9yn974EfSeoAtpIFMTOzPsl/q/KetnFlzk3lK6RPZUVu/xBwV0T8ZyEPj4glwJIeaTfm9hfx5kiu/DX7yEaA9fbMtWT9Or3d8/uF5MvMrFCTx9UzRP5WpVCFBJVFwL6I6IIjTVP1EeHhEGY24A2tHcKkcfX+VqVABX1RD4zIHY8A/ndpsmNmVnnamz1bcaEKCSrD80sIp/360mXJzKyyTGmq56XNe/DYn+MrJKjslnRO94Gk9wB7S5clM7PK0t7cwK79h9i860C5s1LxCulT+WPgHyVtIltO+BSy5YXNzAaF7mHFL27eTcuoYWXOTWUr5OPHxySdAbwzJa2JiIOlzZaZWeWYkhtWfN6UxjLnprIVMvfXl4GGiHg6Ip4GRkr6o9JnzcysMkwcO4LaIeJFd9YfVyF9Kl/snjoFICK2AV8sXZbMzCpLbc0QWhvrPVtxAQoJKjX5KeTTOimeAMfMBpX25gZ/q1KAQoLKvwP3SPqgpA8CdwH3lTZbZmaVpb2pgZe3eFjx8RQy+utrwDXAH6bjVWQjwMzMBo0pzfXsPdjFazv3c8qY4eXOTsUqZOXHw8CvgJfI5ty6GFhd2myZmVWWKc0jAS8tfDxHralIegdwVdo2A/cARMQH+idrZmaVo705m0jkxc27ueD0ty0qa8mxmr/+C3gY+N2I6ACQ9Cf9kiszswpz2pgRDK0d4jnAjuNYzV9XAK8AD0r6u9RJr2Ncb2Y2YA0ZItoaPVvx8Rw1qETEzyJiHnAG2WqLfwyMl/Q9SR/urwyamVWK9uYGf6tyHIV01O+OiJ+kteonka22+LWS58zMrMJMaW7g5a17OHzYw4qP5oTWqI+IbRGxICI+WMj1kuZKWiOpQ9L1vZxvk7RM0ipJyyVNyp37tqRnJK2WdKsyoyQ9mds2S/rrdP3Vkjpz575wImUzMzue9qYGDhw6zKYdfZuofdP2vfzT4xtY+IsXB9x3L4V8p3JS0pf3twMfAjYAj0laHBHP5i67BbgzIu6QdDFwM/BpSe8D3g/MTNf9ArgoIpYDs3LveBz4X7nn3RMR15aqTGY2uHWPAHtp8x4mjSt8WanON/bzyNotPPLCFh55YTMvbXlz4dw572xhasvIoue1XEoWVMi+aelIa8oj6W7gMiAfVGYAf5r2HwR+lvYDGE42HYyAOuC1/MPTkOfxZCPUzMxKbuqRb1V28ZvTm992/sChw7zQuYs1r77BmtfeyH6++gYbt2c1m1HDannv1EY+fUE7o4fX8ueLVrG2c7eDSoEmAutzxxuA9/a4ZiXZKLO/AS4HRklqiohHJD1INvpMwG0R0fODy3lkNZN83fHjkn4LeA74k4hY3+MeJF1DNkMAra2tJ104Mxt8Jowexoi6Gl7cnNU09h/q4ol12/nF85v5Rcdmnt64g0Opv6WuRpzeMpL3tI3j0xe0ccHUJs48bTS1NVmvw/Y92YJfA200WSmDSiGuA26TdDXwELAR6JI0DXgX2cAAgJ9LujAi8rWSecCnc8f3AndFxH5JXwLuIPv6/y0iYgGwAGD27NkDqzHTzEpKEm1N9Sxf8zprN+/iV2u3svdgFzVDxLsnjeGLvzWVM04ZxRmnjGZKcwNDa4/ebT22fiiNDUNZu3nXUa+pRqUMKhuBybnjSSntiIjYRFZTQdJI4OMRsV3SF4FHI2JXOncfcAGpqUvSu4HaiHg896wtuUf/APh20UtkZoPeWRPHsOjxDSD45OxJvH9aM+ef3sTo4XUn/KypzQ280OmaSqEeA6ZLmkIWTOYB/0/+AknNwNY0v9gNwMJ0ah3wRUk3kzV/XQT8de7Wq8hmS84/69SIeCUdXornJzOzEvirj53FX/zOOxk/uu+TSk5pbmD5c51FyFXlOKEhxSciIg4B1wJLyf7A/zQinpE0X9Kl6bI5wBpJzwETgJtS+iLgBeApsn6XlRFxb+7xn6RHUAG+moYgrwS+Clxd/FKZ2WA3vK6mKAEFYGrLSDrf2M8b+wbOCu0l7VOJiCXAkh5pN+b2F5EFkJ73dQFfOsZzp/aSdgNZbcfMrCpMbWkAYG3nbt49eWyZc1McJaupmJnZsU1tTkFlAHXWO6iYmZVJa1M9QwQvDqDOegcVM7MyGVZbw+TGel4YQN+qOKiYmZXR1OYG1rqmYmZmxTCleSQvbd49YGY+dlAxMyujqS0N7D3Yxas795U7K0XhoGJmVkb5YcUDgYOKmVkZ5Wc+HggcVMzMymjC6GE0DK0ZMHOAOaiYmZWRJKa0NLB2gAwrdlAxMyuzKc0jWdvp5i8zMyuCqc0NbNy+l30Hu8qdlT5zUDEzK7OpLQ1EwMu5teurlYOKmVmZdY8AGwhNYA4qZmZlNqX7W5UB0FnvoGJmVmYjh9UyYfSwAfEBZEmDiqS5ktZI6pB0fS/n2yQtk7RK0nJJk3Lnvp1Wclwt6VZJSunL0zOfTNv4lD5M0j3pXb+S1F7KspmZFdPU5pEDYl2VkgUVSTXA7cAlwAzgKkkzelx2C3BnRMwE5gM3p3vfB7wfmAmcBZxLtk59tz+IiFlpez2lfR7YFhHTgO8A3ypNyczMim9KSwMvuvnrmM4DOiJibUQcAO4GLutxzQzggbT/YO58AMOBocAwoA547Tjvuwy4I+0vAj7YXbsxM6t0U5sb2L7nIFt3Hyh3VvqklEFlIrA+d7whpeWtBK5I+5cDoyQ1RcQjZEHmlbQtjYjVufv+ITV9/Y9c4Djyvog4BOwAmnpmStI1klZIWtHZ2dm3EpqZFcnpLQNjBFi5O+qvAy6S9ARZ89ZGoEvSNOBdwCSyYHGxpAvTPX8QEb8BXJi2T5/ICyNiQUTMjojZLS0txSqHmVmfTGkeGCPAShlUNgKTc8eTUtoREbEpIq6IiLOBr6e07WS1lkcjYldE7ALuAy5I5zemn28APyFrZnvL+yTVAmOALaUpmplZcU0aN4K6GlX9CLBSBpXHgOmSpkgaCswDFucvkNQsqTsPNwAL0/46shpMraQ6slrM6nTcnO6tA34XeDrdsxj4bNr/BPBARAyMpdTMbMCrrRlCW1ODm7+OJvVrXAssBVYDP42IZyTNl3RpumwOsEbSc8AE4KaUvgh4AXiKrN9lZUTcS9Zpv1TSKuBJstrJ36V7/h5oktQB/CnwtiHMZmaVbEpz9c9WXFvKh0fEEmBJj7Qbc/uLyAJIz/u6gC/1kr4beM9R3rUP+P0+ZtnMrGymtjSwfM3rdB0OaoZU5+DVcnfUm5lZcnrzSA52BRu2Ve/Ekg4qZmYVYkoR16vveH0XP3r0ZTbv2t/nZ52IkjZ/mZlZ4abmhhV/4ATvPXw4WLVxB0ufeZX7n3n1yPLE23Yf4KsfnF7knB6dg4qZWYVobBjKmBF1BY0Aiwhe2rKHJ9dvY8VL21i2+nVe3bmPmiHi/KmNfOaCdr7/Hy/w/Ov9O5rMQcXMrEJIykaA9dL8te9gF796cSu/fnkbT67fzsoN29m+5yAA9UNruHB6M39x5ju5+IzxjK0fCsB/PNfJ86+90a9lcFAxM6sgU1sa+GVH9t32voNdLF/TyZKnXmHZ6tfYfaALCd4xfhS/M+MUZrWOZdbksUwfP5Lamrd3kU8bP5JfdGzu19FkDipmZhXk9JaR/K9fb+SPfvw4y9d0sudAF+Pq6/i9d5/G75x1Cue2NzJyWGF/uqe1jOTAocNs2LaHtqaGEuc846BiZlZBZpw2GoD/8+JWLj97Ih/5jVN575TGXmsixzNtQjZJ5fOv7XJQMTMbjOa8o4Vlf3YR7U0NfW6ymjY+Cyodnbv4bSYUI3vH5aBiZlZBJB2ZBr+vRg+vY/yoYTz/Wv+NAPPHj2ZmA9j0CSPp6MdJKh1UzMwGsGktI3nh9V3016TtDipmZgPYtPEj2bX/EK/u3Ncv73NQMTMbwKaNHwVkc4H1BwcVM7MB7MgIMAcVMzPrq+aRQxlbX9dvc4CVNKhImitpjaQOSW9biVFSm6RlklZJWi5pUu7ctyU9I2m1pFuVqZf0b5L+K537Zu76qyV1SnoybV8oZdnMzKqBJKa1jKz+moqkGuB24BJgBnCVpBk9LrsFuDMiZgLzgZvTve8D3g/MBM4CziVbpx7glog4AzgbeL+kS3LPuyciZqXtByUqmplZVZk2fgAEFeA8oCMi1kbEAeBu4LIe18wAHkj7D+bOBzAcGEq2Ln0d8FpE7ImIBwHSM38NTMLMzI5q2viRbN19gK27D5T8XaUMKhOB9bnjDSktbyVwRdq/HBglqSkiHiELMq+kbWlErM7fKGks8HvAslzyx1NT2iJJk3vLlKRrJK2QtKKzs/Nky2ZmVjX6s7O+3B311wEXSXqCrHlrI9AlaRrwLrJayETgYkkXdt8kqRa4C7g1Itam5HuB9tSU9nPgjt5eGBELImJ2RMxuaWkpVbnMzCpGd1B5/vXSr61SyqCyEcjXFialtCMiYm4B870AAAkiSURBVFNEXBERZwNfT2nbyWotj0bErojYBdwHXJC7dQHwfET8de5ZWyKiezHmHwDvKXaBzMyq0WljRlA/tKbqayqPAdMlTZE0FJgHLM5fIKlZUncebgAWpv11ZDWYWkl1ZLWY1emevwLGAH/c41mn5g4v7b7ezGywGzIkm6SyqoNKRBwCrgWWkv2B/2lEPCNpvqRL02VzgDWSngMmADel9EXAC8BTZP0uKyPi3jTk+OtkHfy/7jF0+KtpmPFK4KvA1aUqm5lZtemvEWAlnfo+IpYAS3qk3ZjbX0QWQHre1wV8qZf0DUCvCwxExA1ktR0zM+th2viR/PMTG3lj30FGDa8r2XvK3VFvZmb9oLuz/oXO3SV9j4OKmdkgML2fhhU7qJiZDQKtjfUMrRnioGJmZn1XWzOE9uZ6Okr8rYqDipnZIDF9/CjXVMzMrDhOHz+SdVv3sO9gV8ne4aBiZjZITBs/ksMBL24u3QgwBxUzs0GiP0aAOaiYmQ0SU5obGCJKugqkg4qZ2SAxvK6G1sZ6XnBQMTOzYij1HGAOKmZmg8jp40eydvMuDnUdLsnzHVTMzAaR6eNHcbArWLd1T0me76BiZjaIvLkKZGmawBxUzMwGkdNbGoDSDSt2UDEzG0RGDa/jslmnMWnciJI8v6SLdJmZWeX5m3lnl+zZJa2pSJoraY2kDknX93K+TdIySaskLU/LBXef+3ZaHni1pFslKaW/R9JT6Zn59EZJP5f0fPo5rpRlMzOztytZUJFUA9wOXEK2pvxVkmb0uOwW4M6ImAnMB25O974PeD8wEzgLOBe4KN3zPeCLwPS0zU3p1wPLImI6sCwdm5lZPyplTeU8oCMi1kbEAeBu4LIe18wAHkj7D+bOBzAcGAoMA+qA1ySdCoyOiEcjIoA7gY+ley4D7kj7d+TSzcysn5QyqEwE1ueON6S0vJXAFWn/cmCUpKaIeIQsyLyStqURsTrdv+Eoz5wQEa+k/VeBCb1lStI1klZIWtHZ2XlyJTMzs16Ve/TXdcBFkp4ga97aCHRJmga8C5hEFjQulnRhoQ9NtZg4yrkFETE7Ima3tLT0uQBmZvamUgaVjcDk3PGklHZERGyKiCsi4mzg6yltO1mt5dGI2BURu4D7gAvS/ZOO8szu5jHSz9eLXyQzMzuWUgaVx4DpkqZIGgrMAxbnL5DULKk7DzcAC9P+OrIaTK2kOrJazOrUvLVT0vlp1NdngH9J9ywGPpv2P5tLNzOzflKyoBIRh4BrgaXAauCnEfGMpPmSLk2XzQHWSHqOrA/kppS+CHgBeIqs32VlRNybzv0R8AOgI11zX0r/JvAhSc8Dv52OzcysHynrfhicJHUCL5/k7c3A5iJmpxq4zIODyzw49KXMbRHRa6f0oA4qfSFpRUTMLnc++pPLPDi4zINDqcpc7tFfZmY2gDiomJlZ0TionLwF5c5AGbjMg4PLPDiUpMzuUzEzs6JxTcXMzIrGQcXMzIrGQeU4ClgT5mpJnZKeTNsXypHPYjpemdM1n5T0bFrz5if9ncdiK+D3/J3c7/g5SdvLkc9iKqDMrZIelPREWvPoI+XIZzH1ZY2naiRpoaTXJT19lPNK61J1pDKf0+eXRoS3o2xADdlX+1PJpuFfCczocc3VwG3lzms/l3k68AQwLh2PL3e+S13mHtd/BVhY7nz3w+95AfDf0v4M4KVy57sfyvyPwGfT/sXAj8qd7z6W+beAc4Cnj3L+I2Szkgg4H/hVX9/pmsqxFbImzEBTSJm/CNweEdsAIqLaJ+880d/zVcBd/ZKz0imkzAGMTvtjgE39mL9S6MsaT1UpIh4Cth7jksvIFkqMiHgUGNs9Me/JclA5tkLWhAH4eKo6LpI0uZfz1aSQMr8DeIek/5T0qKS5VLdCf89IagOm8OYfnmpVSJn/EviUpA3AErIaWjU76TWe+iFv5VLwv/1COaj03b1Ae2RLIv+cN1efHMhqyZrA5pD9X/vfSRpb1hz1n3nAoojoKndG+sFVwA8jYhJZM8mPcrOKD1S9rvFU3ixVl4H+D6SvClkTZktE7E+HPwDe0095K5Xjlpns/2YWR8TBiHgReI4syFSrQsrcbR7V3/QFhZX588BPASJbjXU42SSE1aovazwNVCfyb78gDirHVsiaMPn2x0vJpvmvZsctM/AzsloKkprJmsPW9mcmi6yQMiPpDGAc8Eg/568UCinzOuCDAJLeRRZUqnkN7r6s8TRQLQY+k0aBnQ/siDeXZT8ptcXJ18AUEYckda8JU0M24ucZSfOBFRGxGPhqWh/mEFmH2NVly3ARFFjmpcCHJT1L1jTw5xGxpXy57psCywzZH6G7Iw2bqWYFlvnPyJo2/4Ss0/7qai57gWWeA9wsKYCHgC+XLcNFIOkusjI1p76xbwB1ABHxfbK+so+QrU+1B/hcn99Zxf9GzMyswrj5y8zMisZBxczMisZBxczMisZBxczMisZBxczMisZBxayIJHXlZjN+UlJ7H583Kz87sKRLjzZztFkl8JBisyKStCsiRh7lnMj+mzt8As+7GpgdEdcWKYtmJeWgYlZEPYNKqqksBX5FNoXPR4DrgXOBEWTziH0jXXsu8DdAA7Af+BDwVLpuI3Bz2p8dEdemZy8kmzqlE/hcRKyT9ENgJzAbOAX4i4hYVMJimx3h5i+z4hqRa/r655Q2HfjbiDgzIl4Gvh4Rs4GZZJMXzkzThtwD/PeIeDfw28Bu4EbgnoiYFRH39HjXd4E70mSmPwZuzZ07FfhN4HeBb5aorGZv42lazIprb0TM6j5ItYmX01oV3T4p6Rqy//5OJVvDI4BXIuIxgIjYme4/1rsu4M1p2n8EfDt37mepme1ZSRP6UiCzE+GgYlZ6u7t3JE0hm1793IjYlpqqhpfgnftz+8eMTGbF5OYvs/41mizI7Eg1iEtS+hrg1NSvgqRRkmqBN4BRR3nWL8kmuQT4A+DhkuXarEAOKmb9KCJWAk8A/wX8BPjPlH4AuBL4rqSVZAu+DSdb0nZG6qO5ssfjvgJ8TtIq4NPAf++fUpgdnUd/mZlZ0bimYmZmReOgYmZmReOgYmZmReOgYmZmReOgYmZmReOgYmZmReOgYmZmRfN/AVw9Xaxjm204AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(*score)\n",
    "plt.xlabel(\"Fraction\")\n",
    "plt.ylabel(\"Accuracy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try different mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "model = model_builder.build_model(model, dropout_rate=0.5, dropout_mask=DecorrelationMask())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Uncertainty estimation with MCDUE_classification approach: 100%|██████████| 25/25 [00:01<00:00, 18.27it/s]\n"
     ]
    }
   ],
   "source": [
    "x_batch, y_batch = next(iter(val_loader))\n",
    "# Calculate uncertainty estimation\n",
    "estimator = MCDUE(model, num_classes=10, acquisition=\"bald\")\n",
    "predictions, estimations = estimator(x_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "score = classification_metric(estimations, np.equal(predictions.argmax(axis=-1), y_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'Accuracy')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEGCAYAAABy53LJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de5RddX338fcnc8nkMpeQDGQuuYDGhkkyCRAotlUgVle8gUKr0CriaqFWsX2eFqvUVe1Dy6K1rmXVWn3QRsRVAc3zqPFpaLQYRCtQYkNuhIQISDIZkgkhkxuZZCbf54+9JxyGSXIGzj7nzJnPa61Z2WffzvdHLh9++7d/eysiMDMzy9e4UhdgZmaji4PDzMxGxMFhZmYj4uAwM7MRcXCYmdmIVJe6gGKYNm1azJ49u9RlmJmNKr/4xS/2RETz0PVjIjhmz57NmjVrSl2GmdmoIulXw633pSozMxsRB4eZmY2Ig8PMzEbEwWFmZiPi4DAzsxHJNDgkLZO0W9LGk2yXpC9I2iZpvaTzc7Z9QNIT6c8HctZfIGlDeswXJCnLNpiZ2Utl3eO4A1h6iu1vBeakPzcAXwaQdAbwaeDXgYuAT0uakh7zZeD6nONOdX4zMyuwTOdxRMQDkmafYpcrgDsjebb7Q5KaJLUAlwI/ioi9AJJ+BCyVdD/QEBEPpevvBN4F3JtF/d9du4Oneg5lcWozG2JuSwNvW9BS6jIsD6WeANgGbM/5vCNdd6r1O4ZZ/zKSbiDpxTBz5sxXVNwP1nWzesvuV3SsmeUvAmqrxvHb555FbbWHXstdqYMjMxFxO3A7wOLFi1/R26qWXXdhQWsys+GtWLeTP7lrLb/sOci5LQ2lLsdOo9TR3gXMyPncnq471fr2Ydab2SjWkYbFYzv3l7gSy0epg2MFcG16d9XFQG9EdAOrgLdImpIOir8FWJVu2y/p4vRuqmuB75esejMriLOnTaKuZhyPdTs4RoNML1VJuotkoHuapB0kd0rVAETEV4CVwNuAbcBh4IPptr2S/gZ4JD3VLYMD5cCHSe7WmkAyKJ7JwLiZFU/VOPFr0xvc4xglsr6r6prTbA/gIyfZtgxYNsz6NcD8ghRoZmWjo6WBlRu6iQg8Pau8lfpSlZkZAB2tDfS+cIydvUdKXYqdhoPDzMrC4AD5Zl+uKnsODjMrC3On1yPhAfJRwMFhZmVh0vhqzp46yQPko4CDw8zKxrktDe5xjAIODjMrGx2tDTyz9zD7jxwrdSl2Cg4OMysbgwPkj3cfKHEldioODjMrGx2tg48e6S1xJXYqDg4zKxtn1o9n6qRaNrvHUdYcHGZWNiTR0eoB8nLn4DCzsnJuSwNbdh3g2MDxUpdiJ+HgMLOy0tHSwNH+4zzpt2+WLQeHmZWVEwPk3R4gL1cODjMrK+dMm0Rt9TjPIC9jDg4zKyvVVeOYO73eA+RlzMFhZmWno6WBzd0HSF7ZY+XGwWFmZaejtYG9h46ya39fqUuxYTg4zKzsnNviAfJy5uAws7Izd3o9gAfIy5SDw8zKTn1dDbOmTvQAeZlycJhZWepoaXCPo0w5OMysLHW0NPD0c4c52Ndf6lJsCAeHmZWlwRnkj/tyVdlxcJhZWRoMjs0OjrLj4DCzsjS9oY6miTUeIC9D1aUuwMxsOJJODJAfOTbAvsPH2PfC0eTXw0fpfeEYh/oGOHy0n0NHBzjcl/xaUzWOT72jgwm1VaVuQsVycJhZ2epoaeBrP3uKuX/176fcr3qcmDS+mpqqcew52MdbOs7isrlnFqnKsSfT4JC0FPg8UAV8LSL+bsj2WcAyoBnYC7wvInak2/4eeHu6699ExD3p+juAS4DBKaXXRcSjWbbDzErj2tfPprZ6HJPrqmmaUEvTxJrkZ0ItDROqmTy+mom11dRWJ1fdD/b1M//Tq9jQ1evgyFBmwSGpCvgS8GZgB/CIpBUR8VjObp8F7oyIb0haAtwGvF/S24HzgUXAeOB+SfdGxODFzo9FxPKsajez8jBz6kT+YuncvPefPL6ac6ZNYmOXH1WSpSwHxy8CtkXEkxFxFLgbuGLIPh3Aj9Pl1TnbO4AHIqI/Ig4B64GlGdZqZhViXlsjmzxxMFNZBkcbsD3n8450Xa51wJXp8ruBeklT0/VLJU2UNA24DJiRc9ytktZL+pyk8cN9uaQbJK2RtKanp6cQ7TGzUWBBWwNd+15g76GjpS6lYpX6dtybgEskrSUZt+gCBiLih8BK4OfAXcCDwEB6zM3AXOBC4Azg48OdOCJuj4jFEbG4ubk521aYWdmY39oI4MtVGcoyOLp4aS+hPV13QkTsjIgrI+I84JPpun3pr7dGxKKIeDMgYGu6vjsSfcDXSS6JmZkBMG8wOHY6OLKSZXA8AsyRdLakWuBqYEXuDpKmSRqs4WaSO6yQVJVeskJSJ9AJ/DD93JL+KuBdwMYM22Bmo0zjxBpmnDGBTV0e58hKZndVRUS/pBuBVSS34y6LiE2SbgHWRMQK4FLgNkkBPAB8JD28Bvhpkg3sJ7lNd/BJZ/8qqZmkF/Io8KGs2mBmo9OCtkb3ODKU6TyOiFhJMlaRu+5TOcvLgZfdVhsRR0jurBrunEsKXKaZVZh5rY2s3PAsvS8co3FCTanLqTilHhw3Myu4+W3JOMcm9zoy4eAws4ozP32yrsc5suHgMLOKM3XyeFob6zzOkREHh5lVpHltjWzwXI5MODjMrCLNb23kqT2H/OrZDDg4zKwizW9rIMJvEMyCg8PMKtLgnVV+9EjhOTjMrCKd1VBHc/14j3NkwMFhZhVrfmuDb8nNgIPDzCrW/LZGnth9gBeODpx+Z8ubg8PMKta81kaOBzz+rHsdheTgMLOKtaB98BHrDo5CcnCYWcVqbaxjysQaNu7wAHkhOTjMrGJJYr4fsV5wDg4zq2jzWhvZuusAff0eIC8UB4eZVbQFbY0cGwie2HWw1KVUDAeHmVW0+W3JI9Y9EbBwHBxmVtFmnjGR+rpqP3qkgBwcZlbRJDGvtcG35BZQpu8cNzMrB/NbG7nzoV9xbOA4NVXJ/y/3Dxxn7+Gj9BzoY8fzL7B97+Hk5/kXeGbvYXbtP8JnrurkrQtaSlx9+XFwmFnFW9DeyNH+41x9+0McPNLPnoN97D18lIiX7lc/vpr2MybymuZJ7Oo9wn2P73ZwDMPBYWYV7/WvmcrCdBb5rKkTuWD2FKZNHk/z5FqmTh5P+5QJzJgykaaJNUgC4INf/y/W79hXyrLLloPDzCremfV1fP/G3xrRMZ3tTfxkaw+H+vqZNN7/VOby4LiZ2TAWzkgekLjJg+ov4+AwMxvGgrYmAF+uGoaDw8xsGM3142ltrGOdH5D4Mg4OM7OT6GxvYoN7HC+TaXBIWippi6Rtkj4xzPZZku6TtF7S/ZLac7b9vaSN6c97c9afLenh9Jz3SKrNsg1mNnYtaG/k6ecO03v4WKlLKSuZBYekKuBLwFuBDuAaSR1DdvsscGdEdAK3ALelx74dOB9YBPw6cJOkhvSYvwc+FxGvBZ4H/iCrNpjZ2LawPR3n6HKvI1eWPY6LgG0R8WREHAXuBq4Ysk8H8ON0eXXO9g7ggYjoj4hDwHpgqZIbrJcAy9P9vgG8K8M2mNkYtqAtmfux3uMcL5FlcLQB23M+70jX5VoHXJkuvxuolzQ1Xb9U0kRJ04DLgBnAVGBfRPSf4pxmZgXROLGG2VMn+s6qIUo9OH4TcImktcAlQBcwEBE/BFYCPwfuAh4ERvQWFkk3SFojaU1PT0+ByzazsaKzvck9jiGyDI4ukl7CoPZ03QkRsTMiroyI84BPpuv2pb/eGhGLIuLNgICtwHNAk6Tqk50z59y3R8TiiFjc3NxcyHaZ2RjS2d5Id+8Rdh84UupSykaWwfEIMCe9C6oWuBpYkbuDpGmSBmu4GViWrq9KL1khqRPoBH4YEUEyFvI76TEfAL6fYRvMbIzrTAfIN7jXcUJmwZGOQ9wIrAI2A9+OiE2SbpF0ebrbpcAWSVuBs4Bb0/U1wE8lPQbcDrwvZ1zj48CfSdpGMubxL1m1wcxsflsD44QnAubI9MldEbGSZKwid92ncpaX8+IdUrn7HCG5s2q4cz5JcseWmVnmJtZWM+fMek8EzFHqwXEzs7K3oL2R9Tt6iaEv8BijThsckt6ZMw5hZjbmLGxv5LlDR+na90KpSykL+QTCe4EnJH1G0tysCzIzKzceIH+p0wZHRLwPOA/4JXCHpAfTORL1mVdnZlYG5rbUU1MlD5Cn8roEFRH7SQax7wZaSGZ5/7ekj2ZYm5lZWRhfXcXc6Q2eQZ7KZ4zjcknfBe4nuU32ooh4K7AQ+PNsyzMzKw+d7Y1s6Orl+HEPkOfT47iK5Gm0CyLiHyJiN0BEHMZPpjWzMaKzvZEDR/p5+rlDpS6l5PIJjr8G/mvwg6QJkmYDRMR9mVRlZlZmBgfI/dyq/ILjO8DxnM8D6TozszFjzpmTqasZ5+Agv+CoTt+nAUC67LfumdmYUl01jnmtjR4gJ7/g6Ml5thSSrgD2ZFeSmVl56mxvZOPOXvoHjp9+5wqWT3B8CPhLSc9I2k7ykME/yrYsM7Pys7C9iSPHjvPE7oOlLqWkTvuQw4j4JXCxpMnp57H9X8zMxqwF7cmrZDfs6OXcloYSV1M6eT0dV9LbgXlAXfLab4iIWzKsy8ys7Jw9dRL1ddWs27GP91w44/QHVKh8JgB+heR5VR8leRPf7wKzMq7LzKzsjBsnFrQ1jvk7q/IZ4/iNiLgWeD4i/hfweuB12ZZlZlaeOtubePzZ/fT1D5S6lJLJJzgGX7R7WFIrcIzkeVVmZmPOwvZGjg0Em7sPlLqUksknOH4gqQn4B+C/gaeBb2VZlJlZueqcMTiDfOzO5zjl4Hj6Aqf7ImIf8H8k/T+gLiLG9gU+MxuzWhvrmDa5lnXbe5ML92PQKXscEXEc+FLO5z6HhpmNZZLobG8a0z2OfC5V3SfpKg3eh2tmNsZ1tjeyrecgB/v6S11KSeQTHH9E8lDDPkn7JR2QtD/juszMytbC9iYixu6rZPN5dWx9RIyLiNqIaEg/j90pk2Y25nWmM8jH6uWq084cl/TG4dZHxAOFL8fMrPxNnTye9ikTxuxEwHweOfKxnOU64CLgF8CSTCoyMxsFFrY3sc49juFFxDtzP0uaAfxjZhWZmY0Cne2N/NuGbp472MfUyeNLXU5R5TM4PtQO4NxCF2JmNpqceJVs19i7XJXPGMcXgUg/jgMWkcwgNzMbsxa0NyLB+u29XPZrZ5a6nKLKZ4xjTc5yP3BXRPxnPieXtBT4PFAFfC0i/m7I9lnAMqAZ2Au8LyJ2pNs+A7ydJKx+BPxpRISk+0melfVCepq3RMTufOoxMyuUyeOreU3z5DF5Z1U+wbEcOBIRAwCSqiRNjIjDpzpIUhXJrPM3k1zeekTSioh4LGe3zwJ3RsQ3JC0BbgPeL+k3gN8EOtP9fgZcAtyffv79iMgNNDOzoutsb+SBrXuICMbSHOm8Zo4DE3I+TwD+I4/jLgK2RcSTEXEUuBu4Ysg+HcCP0+XVOduD5A6uWmA8UAPsyuM7zcyKZmF7E3sO9tHde+T0O1eQfIKjLvd1senyxDyOawO253zeka7LtQ64Ml1+N1AvaWpEPEgSJN3pz6qI2Jxz3NclPSrpr072KBRJN0haI2lNT09PHuWamY3MWJ0ImE9wHJJ0/uAHSRfw4vjCq3UTcImktSSXorqAAUmvJblzq50kbJZIekN6zO9HxALgDenP+4c7cUTcHhGLI2Jxc3Nzgco1M3vRuS0NVI8T68bYRMB8xjj+B/AdSTtJXh07neRVsqfTBeS+lLc9XXdCROwk7XFImgxcFRH7JF0PPDTY05F0L8kDjH8aEV3psQckfYvkktidedRjZlZQdTVVzG2pd49jqIh4BJgL/DHwIeDciPhFHud+BJgj6WxJtcDVwIrcHSRNS9/5AXAzyR1WAM+Q9ESqJdWQ9EY2p5+npcfWAO8ANuZRi5lZJpJHrPdy/HicfucKcdrgkPQRYFJEbIyIjcBkSR8+3XER0Q/cCKwCNgPfjohNkm6RdHm626XAFklbgbOAW9P1y4FfAhtIxkHWRcQPSAbKV0laDzxK0oP5at6tNTMrsIXtjRw40s9Tzx0qdSlFk8+lqusjIvdlTs+nl5L++XQHRsRKYOWQdZ/KWV5OEhJDjxsgeZz70PWHgAvyqNnMrChOzCDfsY/XNE8ucTXFkc/geFXunUvp/Iza7EoyMxs95pw5mbqaccmrZMeIfHoc/w7cI+l/p5//CLg3u5LMzEaP6qpxzG9tHFMD5Pn0OD5OMknvQ+nPBl46IdDMbEzrbG9i0879HBs4XupSiiKfu6qOAw8DT5Pc+rqEZLDbzMyAhTMa6es/ztZdB0pdSlGc9FKVpNcB16Q/e4B7ACLisuKUZmY2Orw4QN7LvNbGEleTvVP1OB4n6V28IyJ+KyK+CAwUpywzs9Fj9tSJNE6oGTPjHKcKjitJnhO1WtJXJb2JZOa4mZnlkERneyNrnxnjwRER34uIq0lmja8mefTImZK+LOktxSrQzGw0OG9GE1t3HeBQX3+pS8lcPoPjhyLiW+m7x9uBtSR3WpmZWWrRzCaOB2wYA6+SHdE7xyPi+fSps2/KqiAzs9FoYTpA/uj2yr9cNaLgMDOz4U2dPJ5ZUyfy6BgY53BwmJkVyKIZTazd/nypy8icg8PMrEAWzWhi1/4+unsL9a678uTgMDMrkEUz0nGOCr9c5eAwMyuQjtYGaqvGVfwAuYPDzKxAxldXcW5rA2sdHGZmlq/zZjSxYUcv/RX8pFwHh5lZAZ03s4kXjg2wddfBUpeSGQeHmVkBnRggr+DLVQ4OM7MCmnnGRM6YVMvaZyp3PoeDw8ysgCSxsL3RPQ4zM8vfohlT2NZzkANHjpW6lEw4OMzMCuy8mU1EJG8ErEQODjOzAltY4QPkDg4zswJrnFDDOc2TKvaNgA4OM7MMLJrRxKPb9xERpS6l4BwcZmYZOG9GE3sO9tG1r/KelJtpcEhaKmmLpG2SPjHM9lmS7pO0XtL9ktpztn1G0iZJmyV9QZLS9RdI2pCe88R6M7NysmjGFIDMLldFBNt2H+RffvYU//gfW4vas6nO6sSSqoAvAW8GdgCPSFoREY/l7PZZ4M6I+IakJcBtwPsl/Qbwm0Bnut/PgEuA+4EvA9cDDwMrgaXAvVm1w8zslZjbUs/46uRJue9c2FqQcx44coyf//I5frK1h59s6XlJb+Z3F8+grWlCQb7ndDILDuAiYFtEPAkg6W7gCiA3ODqAP0uXVwPfS5cDqANqAQE1wC5JLUBDRDyUnvNO4F04OMyszNRUjWNB2yufCHj4aD+buw+waWcvm7r2s6m7l8e7D9B/PJhUW8VvvnYaH77sNYyTuPn/buDZ3hcqIjjagO05n3cAvz5kn3XAlcDngXcD9ZKmRsSDklYD3STB8U8RsVnS4vQ8uedsG+7LJd0A3AAwc+bMAjTHzGxkFs1o4psP/YpjA8epqTr1yMChvn5+tm0P92/ZzX89tZcn9xxi8OrTlIk1zGtt5Po3nsMb5zRzwawp1FYn59vy7AEAdu47wgWzMm3OCVkGRz5uAv5J0nXAA0AXMCDptcC5wOCYx48kvQHIe5QpIm4HbgdYvHhx5d3WYGZlb9HMJr72s6d4vPsAC9obX7b96T2H+PHju1m9ZTcPP7mXowPHmTy+movPOYN3LmxlXmsj81obaGms42TDudMb6wB4tvdIpm3JlWVwdAEzcj63p+tOiIidJD0OJE0GroqIfZKuBx6KiIPptnuB1wPf5MUwGfacZmbl4sUn5T5Pa1MdG7p62djVm/66/8QYxTnNk7j29bNYMvdMFs8+40RvIh8NddVMqq1iZxHfc55lcDwCzJF0Nsk/7lcDv5e7g6RpwN6IOA7cDCxLNz0DXC/pNpJLVZcA/xgR3ZL2S7qYZHD8WuCLGbbBzOwVa2uawLTJ4/mbf9vMX31/04n1s6dO5LyZTfzhG85mydwzmTV10iv+DklMb6yrjB5HRPRLuhFYBVQByyJik6RbgDURsQK4FLhNUpBcqvpIevhyYAmwgWSg/N8j4gfptg8DdwATSAbFPTBuZmVJEn/yptey5unnWdDWyLy2Bua3NdJQV1PQ72ltmsDOIgaHKnFW41CLFy+ONWvWlLoMM7NMfOw763jgiR4e/svfLuh5Jf0iIhYPXe+Z42Zmo1xLYx27D/RxrEjvOXdwmJmNci1NE4iA3Qf6ivJ9Dg4zs1HuxVtyi3NnlYPDzGyUa21MZozv3FecAXIHh5nZKFfsSYAODjOzUa7YkwAdHGZmo1yxJwE6OMzMKkBr0wS6HRxmZpav6Q11dPtSlZmZ5auYkwAdHGZmFaCYkwAdHGZmFaCYkwAdHGZmFaCYkwAdHGZmFaCYkwAdHGZmFWBwEmAxbsl1cJiZVYDBSYDFuCXXwWFmViGKNQnQwWFmViGKNQnQwWFmViGKNQnQwWFmViGKNQnQwWFmViGKNQnQwWFmViEGJwFmPUDu4DAzqxCDPY7ujGePOzjMzCpEsSYBOjjMzCpEsSYBOjjMzCpIMSYBOjjMzCpIMSYBZhockpZK2iJpm6RPDLN9lqT7JK2XdL+k9nT9ZZIezfk5Iuld6bY7JD2Vs21Rlm0wMxtNijEJMLPgkFQFfAl4K9ABXCOpY8hunwXujIhO4BbgNoCIWB0RiyJiEbAEOAz8MOe4jw1uj4hHs2qDmdloMzgJsCfDSYBZ9jguArZFxJMRcRS4G7hiyD4dwI/T5dXDbAf4HeDeiDicWaVmZhXixC25GV6uyjI42oDtOZ93pOtyrQOuTJffDdRLmjpkn6uBu4asuzW9vPU5SeOH+3JJN0haI2lNT0/PK2uBmdkoU4xJgKUeHL8JuETSWuASoAsYGNwoqQVYAKzKOeZmYC5wIXAG8PHhThwRt0fE4ohY3NzcnFH5ZmblpRiTAKszO3MSAjNyPren606IiJ2kPQ5Jk4GrImJfzi7vAb4bEcdyjulOF/skfZ0kfMzMjOJMAsyyx/EIMEfS2ZJqSS45rcjdQdI0SYM13AwsG3KOaxhymSrthSBJwLuAjRnUbmY2KhVjEmBmwRER/cCNJJeZNgPfjohNkm6RdHm626XAFklbgbOAWwePlzSbpMfykyGn/ldJG4ANwDTgb7Nqg5nZaJT1JMAsL1URESuBlUPWfSpneTmw/CTHPs3LB9OJiCWFrdLMrLJMb6jjiV17Mjt/qQfHzcyswJJJgEfoz2gSoIPDzKzCtDRN4HiGbwJ0cJiZVZisJwE6OMzMKkzWkwAdHGZmFSbrSYAODjOzCpP1JEAHh5lZhRmcBPjsfo9xmJlZnlqbJrDTl6rMzCxf0xvqeNaXqszMLF9ZTgJ0cJiZVaAsJwE6OMzMKtCvTa/n7Z0tHI8o+LkzfcihmZmVxvkzp3D+703J5NzucZiZ2Yg4OMzMbEQcHGZmNiIODjMzGxEHh5mZjYiDw8zMRsTBYWZmI+LgMDOzEVFkMKuw3EjqAX71Cg+fBuwpYDmjgds8NrjNle/VtndWRDQPXTkmguPVkLQmIhaXuo5icpvHBre58mXVXl+qMjOzEXFwmJnZiDg4Tu/2UhdQAm7z2OA2V75M2usxDjMzGxH3OMzMbEQcHGZmNiIOjpSkpZK2SNom6RPDbL9OUo+kR9OfPyxFnYV0ujan+7xH0mOSNkn6VrFrLKQ8fo8/l/P7u1XSvlLUWUh5tHmmpNWS1kpaL+ltpaizkPJo8yxJ96XtvV9SeynqLCRJyyTtlrTxJNsl6Qvpf5P1ks5/VV8YEWP+B6gCfgmcA9QC64COIftcB/xTqWstcpvnAGuBKennM0tdd5btHbL/R4Flpa67CL/HtwN/nC53AE+Xuu4itPk7wAfS5SXAN0tddwHa/UbgfGDjSba/DbgXEHAx8PCr+T73OBIXAdsi4smIOArcDVxR4pqylk+brwe+FBHPA0TE7iLXWEgj/T2+BrirKJVlJ582B9CQLjcCO4tYXxbyaXMH8ON0efUw20ediHgA2HuKXa4A7ozEQ0CTpJZX+n0OjkQbsD3n84503VBXpd285ZJmFKe0zOTT5tcBr5P0n5IekrS0aNUVXr6/x0iaBZzNi/+4jFb5tPmvgfdJ2gGsJOlpjWb5tHkdcGW6/G6gXtLUItRWSnn/+c+HgyN/PwBmR0Qn8CPgGyWupxiqSS5XXUryf+BfldRU0oqK42pgeUQMlLqQIrgGuCMi2kkuZ3xTUqX/u3ATcImktcAlQBcwFn6vC6bS/4DkqwvI7UG0p+tOiIjnIqIv/fg14IIi1ZaV07aZ5P9KVkTEsYh4CthKEiSjUT7tHXQ1o/8yFeTX5j8Avg0QEQ8CdSQPxhut8vm7vDMiroyI84BPputG/Y0QpzGSP/+n5eBIPALMkXS2pFqSfzhW5O4w5Hrg5cDmItaXhdO2GfgeSW8DSdNILl09WcwiCyif9iJpLjAFeLDI9WUhnzY/A7wJQNK5JMHRU9QqCyufv8vTcnpVNwPLilxjKawArk3vrroY6I2I7ld6surC1TV6RUS/pBuBVSR3ZSyLiE2SbgHWRMQK4E8kXQ70kwxCXVeyggsgzzavAt4i6TGSrvzHIuK50lX9yuXZXkj+obk70ltRRrM82/znJJcg/yfJQPl1o7ntebb5UuA2SQE8AHykZAUXiKS7SNo1LR2v+jRQAxARXyEZv3obsA04DHzwVX3fKP4zYmZmJeBLVWZmNiIODjMzGxEHh5mZjYiDw8zMRsTBYWZmI+LgMBshSQM5T9F9VNLsV3m+RblPpZV0+cmeVmxWDnw7rtkISToYEZNPsk0kf6+Oj+B81wGLI+LGApVolikHh9kIDQ2OtMexCniY5FE0bwM+AVwITCB57tWn030vBD4PTAL6gDcDG9L9uoDb0uXFEXFjeu5lJI8B6QE+GBHPSLoD2A8sBqYDfxERyzNsttkJvlRlNnITci5TfTddNwf454iYFxG/Aj4ZEYuBTpIH6nWmj8C4B/jTiFgI/DZwCPgUcE9ELIqIe4Z81xeBb6QP1/xX4As521qA3+R1O30AAAELSURBVALeAfxdRm01exk/csRs5F6IiEWDH9Jewa/S9xwMeo+kG0j+jrWQvAMigO6IeAQgIvanx5/qu17Pi48A/ybwmZxt30sviT0m6axX0yCzkXBwmBXGocEFSWeTPLr7woh4Pr2sVJfBd/blLJ8yfcwKyZeqzAqvgSRIetOewFvT9VuAlnScA0n1kqqBA0D9Sc71c5IHLwL8PvDTzKo2y5ODw6zAImIdybvaHwe+Bfxnuv4o8F7gi5LWkbwQrI7k9aUd6ZjJe4ec7qPAByWtB94P/GlxWmF2cr6ryszMRsQ9DjMzGxEHh5mZjYiDw8zMRsTBYWZmI+LgMDOzEXFwmJnZiDg4zMxsRP4/5kr0aL+87D4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(*score)\n",
    "plt.xlabel(\"Fraction\")\n",
    "plt.ylabel(\"Accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
